/-
Copyright (c) 2014 Parikshit Khanna. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Parikshit Khanna, Jeremy Avigad, Leonardo de Moura, Floris van Doorn, Mario Carneiro
-/
module

prelude
public import Init.Data.List.Pairwise
public import Init.Data.List.Zip

public section

/-!
# Lemmas about `List.range` and `List.zipIdx`

Most of the results are deferred to `Init.Data.List.Nat.Range`, where more results about
natural-number arithmetic are available.
-/

set_option linter.listVariables true -- Enforce naming conventions for `List`/`Array`/`Vector` variables.
set_option linter.indexVariables true -- Enforce naming conventions for index variables.

namespace List

open Nat

/-! ## Ranges and enumeration -/

/-! ### range' -/

/-- `range' s n step` has exactly `n` elements, for any start `s` and step `step`. -/
@[simp, grind =] theorem length_range' {s step} : ∀ {n : Nat}, length (range' s n step) = n
  | 0 => rfl
  -- One more element than the recursive call: lift its length equation through `succ`.
  | _ + 1 => congrArg succ length_range'

/-- `range' s n step` is the empty list exactly when `n = 0`. -/
@[simp, grind =] theorem range'_eq_nil_iff : range' s n step = [] ↔ n = 0 := by
  -- Emptiness is a statement about the length, which `length_range'` computes as `n`.
  rw [← length_eq_zero_iff, length_range']

/-- `range' s n step` is nonempty exactly when `n ≠ 0`; this is the negation of
`range'_eq_nil_iff` on both sides. -/
theorem range'_ne_nil_iff (s : Nat) {n step : Nat} : range' s n step ≠ [] ↔ n ≠ 0 :=
  not_congr range'_eq_nil_iff

/--
`range' s n step` equals `a :: xs` exactly when:
* the range starts at `a` (`s = a`);
* it is nonempty (`0 < n`);
* `xs` is the rest of the range, `range' (a + step) (n - 1) step`.
-/
theorem range'_eq_cons_iff : range' s n step = a :: xs ↔ s = a ∧ 0 < n ∧ xs = range' (a + step) (n - 1) step := by
  induction n generalizing s with
  | zero => simp
  | succ n ih =>
    -- Unfold one element of the range and split the cons equality into head and tail parts.
    simp only [range'_succ]
    simp only [cons.injEq, and_congr_right_iff]
    -- With the head equation `s = a` substituted, `simp` closes the remaining tail condition.
    rintro rfl
    simp [eq_comm]

/-- Dropping the first element of `range' s n step` leaves the range that starts one step
later, `s + step`, with `n - 1` elements (which is `0` when `n = 0`). -/
@[simp, grind =] theorem tail_range' : (range' s n step).tail = range' (s + step) (n - 1) step := by
  cases n with
  | zero => simp
  | succ n => simp [range'_succ]

/-- Two unit-step ranges `range' s n` and `range' s' n'` are equal iff both of these hold:
* they have the same length (`n = n'`);
* either they are empty (`n = 0`) or they share the same start (`s = s'`). -/
@[simp, grind =] theorem range'_inj : range' s n = range' s' n' ↔ n = n' ∧ (n = 0 ∨ s = s') := by
  constructor
  · intro h
    -- Equal lists have equal lengths, which forces `n = n'`.
    have h' := congrArg List.length h
    simp at h'
    subst h'
    cases n with
    | zero => simp
    | succ n =>
      -- Nonempty case: unfold the head of each range and let `simp_all` compare them.
      simp only [range'_succ] at h
      simp_all
  · rintro ⟨rfl, rfl | rfl⟩ <;> simp

/-- `m` belongs to `range' s n step` iff `m = s + step * i` for some index `i < n`. -/
theorem mem_range' : ∀ {n}, m ∈ range' s n step ↔ ∃ i < n, m = s + step * i
  | 0 => by simp [range', Nat.not_lt_zero]
  | n + 1 => by
    -- Auxiliary case split: `i ≤ n` iff `i` is zero or the successor of some `j < n`.
    have h (i) : i ≤ n ↔ i = 0 ∨ ∃ j, i = succ j ∧ j < n := by
      cases i <;> simp [Nat.succ_le_iff, Nat.succ_inj]
    simp [range', mem_range', Nat.lt_succ_iff, h]; simp only [← exists_and_right, and_assoc]
    rw [exists_comm]; simp [Nat.mul_succ, Nat.add_assoc, Nat.add_comm]

/-- For an in-bounds index `i < n`, the optional `i`-th element of `range' s n step`
is `some (s + step * i)`. -/
theorem getElem?_range' {s step} :
    ∀ {i n : Nat}, i < n → (range' s n step)[i]? = some (s + step * i)
  | 0, n + 1, _ => by simp [range'_succ]
  | m + 1, n + 1, h => by
    -- Skip the head and recurse into the tail, which starts at `s + step`;
    -- then realign `s + step + step * m` with `s + step * (m + 1)`.
    simp only [range'_succ, getElem?_cons_succ]
    exact (getElem?_range' (s := s + step) (by exact succ_lt_succ_iff.mp h)).trans <| by
      simp [Nat.mul_succ, Nat.add_assoc, Nat.add_comm]

/-- The `i`-th element of `range' n m step` is `n + step * i`.
Here `n` is the start and `m` the number of elements. -/
@[simp, grind =] theorem getElem_range' {n m step} {i} (H : i < (range' n m step).length) :
    (range' n m step)[i] = n + step * i :=
  -- Derived from the `?`-indexed version `getElem?_range'`.
  (getElem?_eq_some_iff.1 <| getElem?_range' (by simpa using H)).2

/-- The first element of the unit-step range `range' s n` is `s`, unless the range is empty. -/
theorem head?_range' : (range' s n).head? = if n = 0 then none else some s := by
  induction n <;> simp_all [range'_succ]

/-- The head of a nonempty unit-step range `range' s n` is its start `s`. -/
@[simp, grind =] theorem head_range' (h) : (range' s n).head h = s := by
  -- `repeat` re-runs `simp_all` until it fails.
  repeat simp_all [head?_range', head_eq_iff_head?_eq_some]

/-- Adding `a` to every element of `range' s n step` shifts the start to `a + s`. -/
theorem map_add_range' {a} : ∀ s n step, map (a + ·) (range' s n step) = range' (a + s) n step
  | _, 0, _ => rfl
  | s, n + 1, step => by simp [range', map_add_range' (s + step), Nat.add_assoc]

/-- Starting a range at `s + 1` is the same as adding `1` to every element of the range
starting at `s`. -/
theorem range'_succ_left : range' (s + 1) n step = (range' s n step).map (· + 1) := by
  -- Compare the two lists elementwise: first the lengths, then each entry.
  apply ext_getElem
  · simp
  · simp [Nat.add_right_comm]

/-- Concatenating `range' s m step` with the range that continues from `s + step * m`
yields a single range of length `m + n`. -/
theorem range'_append : ∀ {s m n step : Nat},
    range' s m step ++ range' (s + step * m) n step = range' s (m + n) step
  | _, 0, _, _ => by simp
  | s, m + 1, n, step => by
    -- Peel one element off the first range and recurse from `s + step`; the arithmetic
    -- lemmas realign `s + step * (m + 1)` with `s + step + step * m`.
    simpa [range', Nat.mul_succ, Nat.add_assoc, Nat.add_comm]
      using range'_append (s := s + step)

/-- Unit-step specialization of `range'_append`. -/
@[simp, grind =] theorem range'_append_1 {s m n : Nat} :
    range' s m ++ range' (s + m) n = range' s (m + n) := by simpa using range'_append (step := 1)

/-- With a fixed start and step, `range' s m step` is a sublist of `range' s n step`
exactly when `m ≤ n`. -/
theorem range'_sublist_right {s m n : Nat} : range' s m step <+ range' s n step ↔ m ≤ n :=
  -- (→) a sublist is no longer than the list containing it;
  -- (←) split the longer range via `range'_append`, so the shorter one is its left part.
  ⟨fun h => by simpa only [length_range'] using h.length_le,
   fun h => by rw [← add_sub_of_le h, ← range'_append]; apply sublist_append_left⟩

/-- For a positive step, `range' s m step ⊆ range' s n step` exactly when `m ≤ n`.
The hypothesis `0 < step` is needed for the forward direction. -/
theorem range'_subset_right {s m n : Nat} (step0 : 0 < step) :
    range' s m step ⊆ range' s n step ↔ m ≤ n := by
  refine ⟨fun h => Nat.le_of_not_lt fun hn => ?_, fun h => (range'_sublist_right.2 h).subset⟩
  -- If `n < m`, the element `s + step * n` of the left range must occur in the right range,
  -- giving some `i < n` with `s + step * n = s + step * i`. Cancelling `s` and the positive
  -- `step` yields `n = i`, contradicting `i < n`.
  have ⟨i, h', e⟩ := mem_range'.1 <| h <| mem_range'.2 ⟨_, hn, rfl⟩
  exact Nat.ne_of_gt h' (Nat.eq_of_mul_eq_mul_left step0 (Nat.add_left_cancel e))

/-- Unit-step specialization of `range'_subset_right`; its positivity side condition
becomes `0 < 1`, supplied by `Nat.zero_lt_one`. -/
theorem range'_subset_right_1 {s m n : Nat} : range' s m ⊆ range' s n ↔ m ≤ n :=
  range'_subset_right Nat.zero_lt_one

/-- A range with `n + 1` elements is the first `n` elements followed by its last element,
`s + step * n`. -/
theorem range'_concat {s n : Nat} : range' s (n + 1) step = range' s n step ++ [s + step * n] := by
  -- This is `range'_append` with a second part of length `1`.
  exact range'_append.symm

/-- Unit-step specialization of `range'_concat`: the last element of `range' s (n + 1)`
is `s + n`. -/
theorem range'_1_concat {s n : Nat} : range' s (n + 1) = range' s n ++ [s + n] := by
  rw [range'_concat, Nat.one_mul]

/-! ### range -/

/-- `range 1` is the singleton `[0]`, by computation. -/
@[simp, grind =] theorem range_one : range 1 = [0] := rfl

/-- Running `range.loop s` on the accumulator `range' s n` produces the full range
`range' 0 (n + s)`. -/
theorem range_loop_range' : ∀ s n, range.loop s (range' s n) = range' 0 (n + s)
  | 0, _ => rfl
  -- Rewrite `n + (s + 1)` as `(n + 1) + s` so the recursive statement at `n + 1` applies.
  | s + 1, n => by rw [← Nat.add_assoc, Nat.add_right_comm n s 1]; exact range_loop_range' s (n + 1)

/-- `range n` is the unit-step range `range' 0 n`. -/
theorem range_eq_range' {n : Nat} : range n = range' 0 n :=
  -- `range_loop_range'` with an empty starting accumulator, followed by `0 + n = n`.
  (range_loop_range' n 0).trans <| by rw [Nat.zero_add]

/-- For an in-bounds index `i < n`, the optional `i`-th element of `range n` is `some i`. -/
theorem getElem?_range {i n : Nat} (h : i < n) : (range n)[i]? = some i := by
  simp [range_eq_range', getElem?_range' h]

/-- The `j`-th element of `range n` is `j`. -/
@[simp, grind =] theorem getElem_range (h : j < (range n).length) : (range n)[j] = j := by
  simp [range_eq_range']

/-- `range (n + 1)` is `0` followed by the successors of the elements of `range n`. -/
theorem range_succ_eq_map {n : Nat} : range (n + 1) = 0 :: map succ (range n) := by
  rw [range_eq_range', range_eq_range', range', Nat.add_comm, ← map_add_range']
  -- What remains identifies the mapped function `(1 + ·)` with `succ`, via commutativity.
  congr; exact funext (Nat.add_comm 1)

/-- The unit-step range `range' s n` is `range n` with `s` added to every element. -/
theorem range'_eq_map_range {s n : Nat} : range' s n = map (s + ·) (range n) := by
  rw [range_eq_range', map_add_range']; rfl

/-- `range n` has exactly `n` elements: rewrite it as `range' 0 n`, whose length is
given by `length_range'`. -/
@[simp, grind =] theorem length_range {n : Nat} : (range n).length = n := by
  rw [range_eq_range', length_range']

/-- `range n` is the empty list exactly when `n = 0`. -/
@[simp, grind =] theorem range_eq_nil {n : Nat} : range n = [] ↔ n = 0 := by
  -- Emptiness reduces to the length being zero, and that length is `n` by `length_range`.
  rw [← length_eq_zero_iff, length_range]

/-- `range n` is nonempty exactly when `n ≠ 0`; this is the negation of `range_eq_nil`
on both sides. -/
theorem range_ne_nil {n : Nat} : range n ≠ [] ↔ n ≠ 0 :=
  not_congr range_eq_nil

/-- Dropping the first element of `range n` leaves the unit-step range starting at `1`
with `n - 1` elements. -/
@[simp, grind =] theorem tail_range : (range n).tail = range' 1 (n - 1) := by
  rw [range_eq_range', tail_range']

/-- `range m` is a sublist of `range n` exactly when `m ≤ n`. -/
@[simp, grind =]
theorem range_sublist {m n : Nat} : range m <+ range n ↔ m ≤ n := by
  simp only [range_eq_range', range'_sublist_right]

/-- `range m` is a subset of `range n` exactly when `m ≤ n`. -/
@[simp, grind =]
theorem range_subset {m n : Nat} : range m ⊆ range n ↔ m ≤ n := by
  -- `lt_succ_self` is supplied, presumably to discharge the side condition `0 < 1`
  -- of `range'_subset_right`.
  simp only [range_eq_range', range'_subset_right, lt_succ_self]

/-- `range (succ n)` is `range n` with `n` appended at the end. -/
theorem range_succ {n : Nat} : range (succ n) = range n ++ [n] := by
  simp only [range_eq_range', range'_1_concat, Nat.zero_add]

/-- `range (n + m)` is `range n` followed by `range m` shifted up by `n`. -/
theorem range_add {n m : Nat} : range (n + m) = range n ++ (range m).map (n + ·) := by
  -- Turn the shifted part back into `range' n m`, then apply `range'_append_1` at start `0`.
  rw [← range'_eq_map_range]
  simpa [range_eq_range', Nat.add_comm] using (range'_append_1 (s := 0)).symm

/-- The first element of `range n` is `0`, unless the range is empty. -/
theorem head?_range {n : Nat} : (range n).head? = if n = 0 then none else some 0 := by
  induction n with
  | zero => simp
  | succ n ih =>
    -- `range_succ` appends `n` at the end, so the head comes from `range n` when that is
    -- nonempty and is `n = 0` otherwise.
    simp only [range_succ, head?_append, ih]
    split <;> simp_all

/-- The head of a nonempty `range n` is `0`. -/
@[simp, grind =] theorem head_range {n : Nat} (h) : (range n).head h = 0 := by
  cases n with
  -- `range 0` is empty, contradicting the nonemptiness hypothesis `h`.
  | zero => simp at h
  | succ n => simp [head?_range, head_eq_iff_head?_eq_some]

/-- The last element of `range n` is `n - 1`, unless the range is empty. -/
theorem getLast?_range {n : Nat} : (range n).getLast? = if n = 0 then none else some (n - 1) := by
  induction n with
  | zero => simp
  | succ n ih =>
    -- `range_succ` writes `range (n + 1)` as `range n ++ [n]`, whose last element is `n`.
    simp only [range_succ, getLast?_append, ih]
    split <;> simp_all

/-- The last element of a nonempty `range n` is `n - 1`. -/
@[simp, grind =] theorem getLast_range {n : Nat} (h) : (range n).getLast h = n - 1 := by
  cases n with
  -- `range 0` is empty, contradicting the nonemptiness hypothesis `h`.
  | zero => simp at h
  | succ n => simp [getLast?_range, getLast_eq_iff_getLast?_eq_some]

/-! ### zipIdx -/

/-- `zipIdx l i` is empty exactly when `l` is empty, whatever the starting index `i`. -/
@[simp]
theorem zipIdx_eq_nil_iff {l : List α} {i : Nat} : List.zipIdx l i = [] ↔ l = [] := by
  cases l <;> simp

/-- `zipIdx` preserves the length of the list, whatever the starting index. -/
@[simp] theorem length_zipIdx : ∀ {l : List α} {i}, (zipIdx l i).length = l.length
  | [], _ => rfl
  -- Cons case: one more element on both sides, lifted through `Nat.succ`.
  | _ :: _, _ => congrArg Nat.succ length_zipIdx

/-- The optional `j`-th entry of `zipIdx l i` is the optional `j`-th element of `l`,
paired with the index `i + j`. -/
@[simp, grind =]
theorem getElem?_zipIdx :
    ∀ {l : List α} {i j}, (zipIdx l i)[j]? = l[j]?.map fun a => (a, i + j)
  | [], _, _ => rfl
  | _ :: _, _, 0 => by simp
  | _ :: l, n, m + 1 => by
    -- Skip the head of both lists and recurse on the tail, then realign the index
    -- arithmetic with `Nat.add_right_comm`.
    simp only [zipIdx_cons, getElem?_cons_succ]
    exact getElem?_zipIdx.trans <| by rw [Nat.add_right_comm]; rfl

/-- The `i`-th entry of `l.zipIdx j` is `(l[i], j + i)`. -/
@[simp, grind =]
theorem getElem_zipIdx {l : List α} (h : i < (l.zipIdx j).length) :
    (l.zipIdx j)[i] = (l[i]'(by simpa [length_zipIdx] using h), j + i) := by
  simp only [length_zipIdx] at h
  -- Go through the `?`-indexed form: `getElem?_zipIdx` plus the in-bounds `getElem?_eq_getElem`.
  rw [getElem_eq_getElem?_get]
  simp only [getElem?_zipIdx, getElem?_eq_getElem h]
  simp

/-- Dropping the first entry of `zipIdx l i` is the same as enumerating `l.tail`
from index `i + 1`. -/
@[simp, grind =]
theorem tail_zipIdx {l : List α} {i : Nat} : (zipIdx l i).tail = zipIdx l.tail (i + 1) := by
  induction l generalizing i with
  | nil => simp
  | cons _ l ih => simp [zipIdx_cons]

/-- Adding `n` to every index of `zipIdx l k` is the same as starting the enumeration
at `n + k`. -/
theorem map_snd_add_zipIdx_eq_zipIdx {l : List α} {n k : Nat} :
    map (Prod.map id (· + n)) (zipIdx l k) = zipIdx l (n + k) :=
  -- Compare the two lists entry by entry via their `?`-indexed elements.
  ext_getElem? fun i ↦ by simp [Nat.add_comm, Nat.add_left_comm]; rfl

/-- Variant of `zipIdx_cons` in which the tail is enumerated from the same start `i`
and all its indices are then incremented by one. -/
theorem zipIdx_cons' {i : Nat} {x : α} {xs : List α} :
    zipIdx (x :: xs) i = (x, i) :: (zipIdx xs i).map (Prod.map id (· + 1)) := by
  rw [zipIdx_cons, Nat.add_comm, ← map_snd_add_zipIdx_eq_zipIdx]

-- Arguments are explicit for parity with `zipIdx_map_fst`.
/-- The indices produced by `zipIdx l i` are exactly `range' i l.length`. -/
@[simp]
theorem zipIdx_map_snd (i : Nat) :
    ∀ (l : List α), map Prod.snd (zipIdx l i) = range' i l.length
  | [] => rfl
  | _ :: _ => congrArg (cons _) (zipIdx_map_snd ..)

-- Arguments are explicit so we can rewrite from right to left.
/-- Projecting out the elements of `zipIdx l i` recovers `l`. -/
@[simp]
theorem zipIdx_map_fst : ∀ i (l : List α), map Prod.fst (zipIdx l i) = l
  | _, [] => rfl
  | _, _ :: _ => congrArg (cons _) (zipIdx_map_fst ..)

/-- `l.zipIdx i` is `l` zipped with the index range `range' i l.length`. -/
theorem zipIdx_eq_zip_range' {l : List α} {i : Nat} : l.zipIdx i = l.zip (range' i l.length) :=
  -- The two projections are known (`zipIdx_map_fst`, `zipIdx_map_snd`), which determines the zip.
  zip_of_prod (zipIdx_map_fst ..) (zipIdx_map_snd ..)

/-- Unzipping `l.zipIdx i` gives back `l` together with the index range `range' i l.length`. -/
@[simp]
theorem unzip_zipIdx_eq_prod {l : List α} {i : Nat} :
    (l.zipIdx i).unzip = (l, range' i l.length) := by
  simp only [zipIdx_eq_zip_range', unzip_zip, length_range']

/-- Replace `zipIdx` with a starting index `i + 1` by `zipIdx` starting from `i`,
followed by a `map` increasing every index by one. -/
theorem zipIdx_succ {l : List α} {i : Nat} :
    l.zipIdx (i + 1) = (l.zipIdx i).map (fun ⟨a, j⟩ => (a, j + 1)) := by
  -- The pattern binder is named `j` so it does not shadow the starting index `i`.
  induction l generalizing i with
  | nil => rfl
  | cons _ _ ih => simp only [zipIdx_cons, ih, map_cons]

/-- Replace `zipIdx` with a starting index `i` by `zipIdx` starting from 0,
followed by a `map` adding `i` to every index. -/
theorem zipIdx_eq_map_add {l : List α} {i : Nat} :
    l.zipIdx i = l.zipIdx.map (fun ⟨a, j⟩ => (a, i + j)) := by
  induction l generalizing i with
  | nil => rfl
  -- Cons case: use the inductive hypothesis at `i + 1` and `zipIdx_succ` on the tail,
  -- then normalize the index arithmetic.
  | cons _ _ ih => simp [ih (i := i + 1), zipIdx_succ, Nat.add_assoc, Nat.add_comm 1]

end List
